Model averaging

Why choose just one model for prediction?

Elizabeth King
Kevin Middleton

Model averaging

  • Fit several candidate models
  • Don’t choose the “best” and pretend it was the only model you fit
    • Combine the posterior distributions from all models
    • Use for analysis of parameter estimates or for prediction of out-of-sample data
    • Weight each model by its relative support
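
The weighting idea in its simplest form is a weighted average of per-model predictions. A minimal sketch (both vectors are made-up numbers for illustration, not estimates from the analysis in these slides):

```r
# Illustrative sketch: average predictions across five models, weighted
# by their relative support. All values here are invented numbers.
w     <- c(0.06, 0.00, 0.04, 0.55, 0.35)  # model weights (sum to 1)
preds <- c(3.4, 3.5, 3.6, 3.7, 3.8)       # one prediction per model
sum(w * preds)                             # model-averaged prediction
```

In practice we average entire posterior (predictive) distributions, not single point predictions.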

Model weights

  • WAIC (lecture 5.2)
  • PSIS-LOO-CV (lecture 5.3)
  • Stacking of means (Yao et al. 2018; preferred)
  • Bayesian model averaging (BMA)
  • Pseudo-BMA with Bayesian bootstrap (2nd preferred)
  • Pseudo-BMA without Bayesian bootstrap
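
The loo package computes these weights directly from pointwise log-likelihoods via `loo_model_weights()`. A sketch, assuming loo objects `loo1`–`loo3` have already been computed for each fitted model (e.g., `loo1 <- loo(fm1)`):

```r
library(loo)

# Stacking weights (preferred)
loo_model_weights(list(loo1, loo2, loo3), method = "stacking")

# Pseudo-BMA with the Bayesian bootstrap (BB = TRUE is the default)
loo_model_weights(list(loo1, loo2, loo3), method = "pseudobma")

# Pseudo-BMA without the Bayesian bootstrap
loo_model_weights(list(loo1, loo2, loo3), method = "pseudobma", BB = FALSE)
```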

Description in the loo package

Five models for energy expenditure

fm1 <- brm(Energy ~ 1, data = M,
           prior = prior(normal(0, 3), class = Intercept), iter = 2e4,
           refresh = 0, seed = 3476283)
fm2 <- brm(Energy ~ Caste, data = M,
           prior = prior(normal(0, 3), class = b), iter = 2e4,
           refresh = 0, seed = 312379)
fm3 <- brm(Energy ~ Mass, data = M,
           prior = prior(normal(0, 3), class = b), iter = 2e4,
           refresh = 0, seed = 12365864)
fm4 <- brm(Energy ~ Mass + Caste, data = M,
           prior = prior(normal(0, 3), class = b), iter = 2e4,
           refresh = 0, seed = 8873542)
fm5 <- brm(Energy ~ Mass * Caste, data = M,
           prior = prior(normal(0, 3), class = b), iter = 2e4,
           refresh = 0, seed = 612356)

Comparison of model weights

mw1 <- model_weights(fm1, fm2, fm3, fm4, fm5, weights = "waic")
mw2 <- model_weights(fm1, fm2, fm3, fm4, fm5, weights = "loo")
mw3 <- model_weights(fm1, fm2, fm3, fm4, fm5, weights = "stacking")
mw4 <- model_weights(fm1, fm2, fm3, fm4, fm5, weights = "pseudobma",
                     BB = FALSE)
mw5 <- model_weights(fm1, fm2, fm3, fm4, fm5, weights = "pseudobma")

tibble(Model = paste0("fm", 1:5),
       WAIC = mw1,
       LOO = mw2,
       Stacking = mw3,
       PseudoBMA = mw4,
       `PseudoBMA + BB` = mw5) |> 
  knitr::kable(digits = 2)

Comparison of model weights

Model   WAIC    LOO   Stacking   PseudoBMA   PseudoBMA + BB
fm1     0.00   0.00       0.06        0.00             0.02
fm2     0.00   0.00       0.00        0.00             0.01
fm3     0.04   0.04       0.04        0.04             0.14
fm4     0.50   0.50       0.56        0.50             0.42
fm5     0.46   0.45       0.35        0.45             0.41

Model weights by stacking

Two kinds of posterior intervals

  1. HDIs for the parameter estimates
    • Credible ranges for expected values
    • posterior_epred() (Lecture 3.4)
  2. HDIs for new or observed values (“posterior predictive distribution”)
    • Include the residual uncertainty (σ)
    • Wider than the expected-value intervals
    • posterior_predict()
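
The two interval types can be computed side by side. A sketch, assuming the fitted model fm4 from earlier; `new_data` is a hypothetical data frame of predictor values (Mass and Caste columns):

```r
# Sketch: expected-value draws vs. posterior predictive draws for fm4.
# `new_data` is a hypothetical grid of new predictor values.
epred <- posterior_epred(fm4, newdata = new_data)    # draws of E[y | x]
ppred <- posterior_predict(fm4, newdata = new_data)  # draws of new y

# Predictive intervals include sigma, so they are wider:
apply(epred, MARGIN = 2, FUN = quantile, probs = c(0.055, 0.945))
apply(ppred, MARGIN = 2, FUN = quantile, probs = c(0.055, 0.945))
```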

Posterior prediction for model 4

pp_fm4 <- crossing(Mass = seq(3.8, 5.3, length.out = 200),
                   Caste = levels(M$Caste))

pp <- posterior_predict(fm4, newdata = pp_fm4) |> 
  as.data.frame()
str(pp)
'data.frame':   40000 obs. of  400 variables:
 $ V1  : num  2.8 3.35 3.56 3.16 3.32 ...
 $ V2  : num  3.36 3.76 3.5 4.27 3.13 ...
 $ V3  : num  3.27 3.07 3.13 3.23 3.27 ...
  [list output truncated]

Posterior prediction for model 4

pp_fm4 <- pp_fm4 |> 
  mutate(Q50 = apply(pp, MARGIN = 2, FUN = median),
         Q5.5 = apply(pp, MARGIN = 2, FUN = quantile, probs = 0.055),
         Q94.5 = apply(pp, MARGIN = 2, FUN = quantile, probs = 0.945))
head(pp_fm4)
# A tibble: 6 × 5
   Mass Caste       Q50  Q5.5 Q94.5
  <dbl> <chr>     <dbl> <dbl> <dbl>
1  3.8  NonWorker  3.30  2.69  3.92
2  3.8  Worker     3.69  3.16  4.23
3  3.81 NonWorker  3.31  2.69  3.92
4  3.81 Worker     3.70  3.16  4.23
5  3.82 NonWorker  3.31  2.70  3.93
6  3.82 Worker     3.70  3.17  4.24

Posterior prediction for model 4

Posterior prediction for model 5

Model averaging by stacking

Combine all models by their relative support

pp_mod_avg <- crossing(Mass = seq(3.8, 5.3, length.out = 200),
                       Caste = levels(M$Caste))

pp <- pp_average(fm1, fm2, fm3, fm4, fm5,
                 newdata = pp_mod_avg,
                 weights = "stacking",
                 method = "posterior_predict",
                 probs = c(0.055, 0.5, 0.945))
attr(pp, "weights") |> round(3)
  fm1   fm2   fm3   fm4   fm5 
0.057 0.000 0.038 0.556 0.350 
attr(pp, "ndraws")
  fm1   fm2   fm3   fm4   fm5 
 2267     0  1521 22231 13981 
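
The "ndraws" attribute shows how pp_average() builds the averaged posterior: each model contributes draws roughly in proportion to its stacking weight. A rough sketch of the same bookkeeping, using the rounded weights printed above (so the counts differ slightly from the exact ndraws):

```r
# Approximate per-model draw counts: total draws times stacking weight.
# Weights copied from the rounded output above.
w <- c(fm1 = 0.057, fm2 = 0.000, fm3 = 0.038, fm4 = 0.556, fm5 = 0.350)
round(w * 40000)  # close to the ndraws attribute (rounding differs)
```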

Join posterior predictions to new data

pp_mod_avg <- bind_cols(pp_mod_avg, pp)
head(pp_mod_avg)
# A tibble: 6 × 7
   Mass Caste     Estimate Est.Error  Q5.5   Q50 Q94.5
  <dbl> <chr>        <dbl>     <dbl> <dbl> <dbl> <dbl>
1  3.8  NonWorker     3.45     0.473  2.73  3.43  4.25
2  3.8  Worker        3.70     0.364  3.14  3.70  4.30
3  3.81 NonWorker     3.45     0.467  2.74  3.43  4.24
4  3.81 Worker        3.71     0.363  3.15  3.70  4.30
5  3.82 NonWorker     3.46     0.470  2.75  3.44  4.26
6  3.82 Worker        3.72     0.362  3.16  3.71  4.31

Plotting model average

Comparison

Bayesian workflow

See also Gelman et al. (2020) and Gabry et al. (2019)

  1. Model specification
  2. Prior specification
  3. Prior predictive simulation / check
  4. Sampling
  5. Diagnostics
  6. Posterior predictive simulation
  7. Summarizing the posterior
  8. Model comparison and averaging

References

Gabry, J., D. Simpson, A. Vehtari, M. Betancourt, and A. Gelman. 2019. Visualization in Bayesian Workflow. J. R. Stat. Soc. Ser. A Stat. Soc. 182:389–402. Wiley.
Gelman, A., A. Vehtari, D. Simpson, C. C. Margossian, B. Carpenter, Y. Yao, L. Kennedy, J. Gabry, P.-C. Bürkner, and M. Modrák. 2020. Bayesian Workflow. arXiv:2011.01808.
Yao, Y., A. Vehtari, D. Simpson, and A. Gelman. 2018. Using Stacking to Average Bayesian Predictive Distributions. Bayesian Analysis 13:917–1007. International Society for Bayesian Analysis.